/**
 * Copyright 2024-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_options.h"

#include "infra/base/assertion.h"
#include "framework/infra/log/log.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/ndk_proxy.h"

namespace hiai {
/**
 * @brief Forwards to HMS_HiAIOptions_SetInputTensorShapes, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation      Compilation handle, passed through unchanged.
 * @param inputTensorDescs Array of input tensor descriptors, passed through unchanged.
 * @param shapeCount       Number of entries in inputTensorDescs, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetInputTensorShapes(OH_NNCompilation* compilation,
    NN_TensorDesc* inputTensorDescs[], size_t shapeCount)
{
    using SetInputTensorShapesFn = decltype(HMS_HiAIOptions_SetInputTensorShapes);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetInputTensorShapes");
    auto setInputTensorShapesFunc = reinterpret_cast<SetInputTensorShapesFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setInputTensorShapesFunc);

    if (setInputTensorShapesFunc(compilation, inputTensorDescs, shapeCount) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetFormatMode, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param formatMode  Format mode value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetFormatMode(OH_NNCompilation* compilation, HiAI_FormatMode formatMode)
{
    using SetFormatModeFn = decltype(HMS_HiAIOptions_SetFormatMode);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetFormatMode");
    auto setFormatModeFunc = reinterpret_cast<SetFormatModeFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setFormatModeFunc);

    if (setFormatModeFunc(compilation, formatMode) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetDynamicShapeStatus, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param status      Dynamic-shape status value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetDynamicShapeStatus(OH_NNCompilation* compilation, HiAI_DynamicShapeStatus status)
{
    using SetDynamicShapeStatusFn = decltype(HMS_HiAIOptions_SetDynamicShapeStatus);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetDynamicShapeStatus");
    auto setDynamicShapeStatusFunc = reinterpret_cast<SetDynamicShapeStatusFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setDynamicShapeStatusFunc);

    if (setDynamicShapeStatusFunc(compilation, status) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetDynamicShapeMaxCache, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation   Compilation handle, passed through unchanged.
 * @param maxCacheCount Cache limit value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetDynamicShapeMaxCache(OH_NNCompilation* compilation, size_t maxCacheCount)
{
    using SetDynamicShapeMaxCacheFn = decltype(HMS_HiAIOptions_SetDynamicShapeMaxCache);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetDynamicShapeMaxCache");
    auto setDynamicShapeMaxCacheFunc = reinterpret_cast<SetDynamicShapeMaxCacheFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setDynamicShapeMaxCacheFunc);

    if (setDynamicShapeMaxCacheFunc(compilation, maxCacheCount) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetDynamicShapeCacheMode, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param mode        Dynamic-shape cache mode value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetDynamicShapeCacheMode(OH_NNCompilation* compilation, HiAI_DynamicShapeCacheMode mode)
{
    using SetDynamicShapeCacheModeFn = decltype(HMS_HiAIOptions_SetDynamicShapeCacheMode);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetDynamicShapeCacheMode");
    auto setDynamicShapeCacheModeFunc = reinterpret_cast<SetDynamicShapeCacheModeFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setDynamicShapeCacheModeFunc);

    if (setDynamicShapeCacheModeFunc(compilation, mode) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetOperatorDeviceOrder, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation    Compilation handle, passed through unchanged.
 * @param operatorName   Operator name string, passed through unchanged.
 * @param executeDevices Array of devices, passed through unchanged.
 * @param deviceCount    Number of entries in executeDevices, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetOperatorDeviceOrder(OH_NNCompilation* compilation,
    const char* operatorName, HiAI_ExecuteDevice* executeDevices, size_t deviceCount)
{
    using SetOperatorDeviceOrderFn = decltype(HMS_HiAIOptions_SetOperatorDeviceOrder);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetOperatorDeviceOrder");
    auto setOperatorDeviceOrderFunc = reinterpret_cast<SetOperatorDeviceOrderFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setOperatorDeviceOrderFunc);

    const auto status = setOperatorDeviceOrderFunc(compilation, operatorName, executeDevices, deviceCount);
    if (status != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetModelDeviceOrder, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation    Compilation handle, passed through unchanged.
 * @param executeDevices Array of devices, passed through unchanged.
 * @param deviceCount    Number of entries in executeDevices, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetModelDeviceOrder(OH_NNCompilation* compilation,
    HiAI_ExecuteDevice* executeDevices, size_t deviceCount)
{
    using SetModelDeviceOrderFn = decltype(HMS_HiAIOptions_SetModelDeviceOrder);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetModelDeviceOrder");
    auto setModelDeviceOrderFunc = reinterpret_cast<SetModelDeviceOrderFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setModelDeviceOrderFunc);

    if (setModelDeviceOrderFunc(compilation, executeDevices, deviceCount) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetFallbackMode, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation  Compilation handle, passed through unchanged.
 * @param fallbackMode Fallback mode value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetFallbackMode(OH_NNCompilation* compilation, HiAI_FallbackMode fallbackMode)
{
    using SetFallbackModeFn = decltype(HMS_HiAIOptions_SetFallbackMode);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetFallbackMode");
    auto setFallbackModeFunc = reinterpret_cast<SetFallbackModeFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setFallbackModeFunc);

    if (setFallbackModeFunc(compilation, fallbackMode) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetDeviceMemoryReusePlan, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation           Compilation handle, passed through unchanged.
 * @param deviceMemoryReusePlan Reuse-plan value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetDeviceMemoryReusePlan(OH_NNCompilation* compilation,
    HiAI_DeviceMemoryReusePlan deviceMemoryReusePlan)
{
    using SetDeviceMemoryReusePlanFn = decltype(HMS_HiAIOptions_SetDeviceMemoryReusePlan);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetDeviceMemoryReusePlan");
    auto setDeviceMemoryReusePlanFunc = reinterpret_cast<SetDeviceMemoryReusePlanFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setDeviceMemoryReusePlanFunc);

    if (setDeviceMemoryReusePlanFunc(compilation, deviceMemoryReusePlan) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetTuningStrategy, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation    Compilation handle, passed through unchanged.
 * @param tuningStrategy Tuning strategy value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetTuningStrategy(OH_NNCompilation* compilation, HiAI_TuningStrategy tuningStrategy)
{
    using SetTuningStrategyFn = decltype(HMS_HiAIOptions_SetTuningStrategy);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetTuningStrategy");
    auto setTuningStrategyFunc = reinterpret_cast<SetTuningStrategyFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setTuningStrategyFunc);

    if (setTuningStrategyFunc(compilation, tuningStrategy) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetQuantConfig, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param data        Opaque configuration buffer, passed through unchanged.
 * @param size        Size argument for data, passed through unchanged (presumably bytes — confirm against the NDK).
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetQuantConfig(OH_NNCompilation* compilation, void* data, size_t size)
{
    using SetQuantConfigFn = decltype(HMS_HiAIOptions_SetQuantConfig);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetQuantConfig");
    auto setQuantConfigFunc = reinterpret_cast<SetQuantConfigFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setQuantConfigFunc);

    if (setQuantConfigFunc(compilation, data, size) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetTuningMode, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param tuningMode  Tuning mode value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetTuningMode(OH_NNCompilation* compilation, HiAI_TuningMode tuningMode)
{
    using SetTuningModeFn = decltype(HMS_HiAIOptions_SetTuningMode);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetTuningMode");
    auto setTuningModeFunc = reinterpret_cast<SetTuningModeFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setTuningModeFunc);

    if (setTuningModeFunc(compilation, tuningMode) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetTuningCacheDir, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param cacheDir    Directory path string, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetTuningCacheDir(OH_NNCompilation* compilation, const char* cacheDir)
{
    using SetTuningCacheDirFn = decltype(HMS_HiAIOptions_SetTuningCacheDir);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetTuningCacheDir");
    auto setTuningCacheDirFunc = reinterpret_cast<SetTuningCacheDirFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setTuningCacheDirFunc);

    if (setTuningCacheDirFunc(compilation, cacheDir) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetBandMode, looked up at call time via NDKProxy::GetSymbol.
 * @param nnCompilation Compilation handle, passed through unchanged.
 * @param bandMode      Band mode value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetBandMode(OH_NNCompilation* nnCompilation, HiAI_BandMode bandMode)
{
    using SetBandModeFn = decltype(HMS_HiAIOptions_SetBandMode);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetBandMode");
    auto setBandModeFunc = reinterpret_cast<SetBandModeFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setBandModeFunc);

    if (setBandModeFunc(nnCompilation, bandMode) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Queries the band mode through HMS_HiAIOptions_GetBandMode, looked up at call time via NDKProxy::GetSymbol.
 * @param nnCompilation Compilation handle, passed through unchanged.
 * @return Whatever the NDK getter returns, or HIAI_BANDMODE_UNSET when the symbol cannot be resolved.
 */
HiAI_BandMode HIAI_NDK_HiAIOptions_GetBandMode(const OH_NNCompilation* nnCompilation)
{
    using GetBandModeFn = decltype(HMS_HiAIOptions_GetBandMode);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_GetBandMode");
    auto getBandModeFunc = reinterpret_cast<GetBandModeFn*>(symbol);
    if (getBandModeFunc == nullptr) {
        // No getter available: report "unset" instead of failing, since this API has no error channel.
        return HIAI_BANDMODE_UNSET;
    }
    return getBandModeFunc(nnCompilation);
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetAsyncModeEnable, looked up at call time via NDKProxy::GetSymbol.
 * @param nnCompilation Compilation handle, passed through unchanged.
 * @param isEnable      Flag value, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetAsyncModeEnable(OH_NNCompilation* nnCompilation, bool isEnable)
{
    using SetAsyncModeEnableFn = decltype(HMS_HiAIOptions_SetAsyncModeEnable);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetAsyncModeEnable");
    auto setAsyncModeEnableFunc = reinterpret_cast<SetAsyncModeEnableFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setAsyncModeEnableFunc);

    if (setAsyncModeEnableFunc(nnCompilation, isEnable) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

bool HIAI_NDK_HiAIOptions_GetAsyncModeEnable(const OH_NNCompilation* nnCompilation)
{
    auto getAsyncModeEnableFunc = reinterpret_cast<decltype(HMS_HiAIOptions_GetAsyncModeEnable)*>(
        NDKProxy::GetSymbol("HMS_HiAIOptions_GetAsyncModeEnable"));
    return getAsyncModeEnableFunc == nullptr ? false : getAsyncModeEnableFunc(nnCompilation);
}

/**
 * @brief Forwards to HMS_HiAIOptions_SetCustomOP, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param customPath  Path string, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS, FAILURE for any other return code.
 *         An unresolved symbol is intercepted by HIAI_EXPECT_NOT_NULL before any call is made.
 */
Status HIAI_NDK_HiAIOptions_SetCustomOpPath(OH_NNCompilation* compilation, const char* customPath)
{
    // Note the NDK entry point is named "SetCustomOP", not "SetCustomOpPath".
    using SetCustomOpFn = decltype(HMS_HiAIOptions_SetCustomOP);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetCustomOP");
    auto setCustomOpFunc = reinterpret_cast<SetCustomOpFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setCustomOpFunc);

    if (setCustomOpFunc(compilation, customPath) != OH_NN_SUCCESS) {
        return FAILURE;
    }
    return SUCCESS;
}

/**
 * @brief Passes allocateFunc to HMS_HiAIOptions_SetAllocate, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation  Compilation handle, passed through unchanged.
 * @param allocateFunc Callback of type onAllocate, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS; otherwise the HIAI_EXPECT_TRUE guard exits early
 *         (as does HIAI_EXPECT_NOT_NULL for an unresolved symbol).
 */
Status HIAI_NDK_HiAIOptions_SetAllocate(OH_NNCompilation* compilation, onAllocate allocateFunc)
{
    using SetAllocateFn = decltype(HMS_HiAIOptions_SetAllocate);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetAllocate");
    auto setAllocateFunc = reinterpret_cast<SetAllocateFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setAllocateFunc);

    // NOTE(review): the guard's argument is kept verbatim in case the assertion macro stringifies it for logging.
    HIAI_EXPECT_TRUE(setAllocateFunc(compilation, allocateFunc) == OH_NN_SUCCESS);
    return SUCCESS;
}

/**
 * @brief Passes freeFunc to HMS_HiAIOptions_SetFree, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param freeFunc    Callback of type onFree, passed through unchanged.
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS; otherwise the HIAI_EXPECT_TRUE guard exits early
 *         (as does HIAI_EXPECT_NOT_NULL for an unresolved symbol).
 */
Status HIAI_NDK_HiAIOptions_SetFree(OH_NNCompilation* compilation, onFree freeFunc)
{
    using SetFreeFn = decltype(HMS_HiAIOptions_SetFree);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetFree");
    auto setFreeFunc = reinterpret_cast<SetFreeFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setFreeFunc);

    // NOTE(review): the guard's argument is kept verbatim in case the assertion macro stringifies it for logging.
    HIAI_EXPECT_TRUE(setFreeFunc(compilation, freeFunc) == OH_NN_SUCCESS);
    return SUCCESS;
}

/**
 * @brief Passes userData to HMS_HiAIOptions_SetUserData, looked up at call time via NDKProxy::GetSymbol.
 * @param compilation Compilation handle, passed through unchanged.
 * @param userData    Opaque pointer, passed through unchanged (ownership not established here — confirm with NDK docs).
 * @return SUCCESS when the NDK call yields OH_NN_SUCCESS; otherwise the HIAI_EXPECT_TRUE guard exits early
 *         (as does HIAI_EXPECT_NOT_NULL for an unresolved symbol).
 */
Status HIAI_NDK_HiAIOptions_SetUserData(OH_NNCompilation* compilation, void* userData)
{
    using SetUserDataFn = decltype(HMS_HiAIOptions_SetUserData);
    auto symbol = NDKProxy::GetSymbol("HMS_HiAIOptions_SetUserData");
    auto setUserDataFunc = reinterpret_cast<SetUserDataFn*>(symbol);
    HIAI_EXPECT_NOT_NULL(setUserDataFunc);

    // NOTE(review): the guard's argument is kept verbatim in case the assertion macro stringifies it for logging.
    HIAI_EXPECT_TRUE(setUserDataFunc(compilation, userData) == OH_NN_SUCCESS);
    return SUCCESS;
}
}
